import torch
import torch.nn as nn
import torch.optim as optim
from src.data.data_utils import choose_dataset
import argparse
# Command-line interface:
#   --m  model architecture, stored as options.model (e.g. resnet / resnet50 / vgg / alexnet)
#   --d  dataset name, stored as options.data (e.g. cifar10 / tiny_imagenet / imagenet)
parser = argparse.ArgumentParser()
parser.add_argument("--m", type=str, default="resnet", dest="model")
parser.add_argument("--d", type=str, default="cifar10", dest="data")
options = parser.parse_args()


# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'use device {device}')


# Define the hyperparameters.
# Shared by every dataset: mini-batch size and SGD momentum.
batch_size = 128
momentum = 0.9

# Dataset-specific SGD settings: (learning rate, number of epochs, weight decay).
if options.data.lower() == 'tiny_imagenet':
    # NOTE(review): the OneCycleLR scheduler set up for tiny_imagenet takes an
    # explicit max_lr, which appears to supersede this learning rate — confirm.
    learning_rate, num_epochs, weight_decay = 0.01, 50, 5e-3
else:
    learning_rate, num_epochs, weight_decay = 0.05, 100, 5e-4


# -------- Dataset Selection -----------
# Dataset name -> (number of classes, root directory). Any name not listed here
# is treated as a 10-class dataset stored under ./data/ (e.g. the cifar10 default).
num_classes, datapath = {
    'imagenet': (1000, "./imageNet/"),
    'tiny_imagenet': (200, "./tiny_imagenet/"),
}.get(options.data.lower(), (10, "./data/"))

train_loader, val_loader, test_loader = choose_dataset(
    dataset_name=options.data.lower(),
    batch_size=batch_size,
    datapath=datapath,
)

# -------- Model Selection -----------
# Resolve --m to a baseline module and a model factory. Names are matched exactly
# so that an unsupported value (e.g. "resnet34" or "vgg16") fails here with a clear
# error instead of a NameError on `baseline`/`model` further down the script.
model_name = options.model.lower()
if model_name in ('resnet', 'resnet50'):
    import src.low_rank_neural_networks.baselines.ResNet as baseline
elif model_name == 'vgg':
    import src.low_rank_neural_networks.baselines.VGG as baseline
elif model_name == 'alexnet':
    import src.low_rank_neural_networks.baselines.alexnet as baseline
else:
    raise ValueError(
        f"Unsupported model {options.model!r}; expected one of: resnet, resnet50, vgg, alexnet"
    )

# Initialize the model on the selected device.
# NOTE(review): only resnet18 receives num_classes; vgg16/resnet50/alexnet rely on
# their factories' default output size — confirm it matches the selected dataset.
if model_name == 'vgg':
    model = baseline.vgg16().to(device)
    print("Train VGG16")
elif model_name == 'resnet':
    model = baseline.resnet18(num_classes=num_classes).to(device)
    print("Train ResNet18")
elif model_name == 'resnet50':
    model = baseline.resnet50().to(device)
    print("Train ResNet50")
else:  # 'alexnet' — every other name was rejected above
    model = baseline.alexnet().to(device)
    print("Train AlexNet")

# Define the loss function and optimizer.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)

# Learning-rate schedule. Exactly one scheduler is attached to the optimizer;
# building a MultiStepLR and then replacing it would leave a discarded scheduler
# that had already modified the optimizer's param groups.
if options.data.lower() == 'tiny_imagenet':
    # One-cycle policy over num_epochs * len(train_loader) iterations, with the
    # warm-up phase covering the first 10 epochs (pct_start = 10 / num_epochs).
    # It is defined per iteration, so it must be stepped after every optimizer step.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.02,
        steps_per_epoch=len(train_loader),
        epochs=num_epochs,
        div_factor=10,
        final_div_factor=10,
        pct_start=10 / num_epochs,
    )
else:
    # Decay the learning rate 10x at epochs 25 and 40 (stepped once per epoch).
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 40], gamma=0.1)

# Total number of scalar entries across all tensors returned by model.parameters()
# (trainable or not). numel() reads the size directly, without allocating a tensor.
total_params = sum(p.numel() for p in model.parameters())

print(f'total_params {total_params}')

# Training loop.
# OneCycleLR is configured for num_epochs * len(train_loader) iterations and must
# advance after every optimizer step; stepping it once per epoch would keep the
# learning rate stuck near its warm-up starting value. Epoch-based schedulers
# (MultiStepLR) advance once per epoch.
step_scheduler_per_batch = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)
        # Clear gradients from the previous step (equivalent to setting p.grad = None).
        optimizer.zero_grad(set_to_none=True)

        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, targets)

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        if step_scheduler_per_batch:
            scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        if (batch_idx + 1) % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {batch_loss:.4f}")

    if not step_scheduler_per_batch:
        scheduler.step()
    print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {total_loss / max(len(train_loader), 1):.4f}")

    # Evaluate on the validation split.
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, targets in val_loader:
            data = data.to(device)
            targets = targets.to(device)

            outputs = model(data)
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        # Guard against an empty validation loader.
        accuracy = 100 * correct / total if total else 0.0
        print(f"Accuracy of the network on the validation images: {accuracy}%")

print("Training finished.")
